-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Decouple convergence checking from SamplerReport
#6453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Decouple convergence checking from SamplerReport
#6453
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6453 +/- ##
==========================================
+ Coverage 94.78% 94.79% +0.01%
==========================================
Files 148 148
Lines 27678 27678
==========================================
+ Hits 26234 26238 +4
+ Misses 1444 1440 -4
|
c6dbdbb to
f419ed3
Compare
f419ed3 to
6c2f7f2
Compare
The goal was to uncouple sampling functions from `MultiTrace` and `SamplerReport`. Some calls to `SamplerReport._log_summary()` were unnecessary because `MultiTrace._add_warnings()` was never called inbetween instantiation and `_log_summary()`, therefore the traces never contained warnings. Running convergence checks and logging the warnings can also be done without needing `MultiTrace` or `SamplerReport` instances/methods.
6c2f7f2 to
49f5263
Compare
* Specify covariant input types in `StatsBijection`. * Annotate `_choose_chains` to be independent of `BaseTrace` type.
|
I don't think I am qualified to review this |
I should have added comments to the diff earlier.. GitHub suggested you because you edited the SMC code? Who else is familiar with it? |
| S = TypeVar("S", bound=Sized) | ||
|
|
||
|
|
||
| def _choose_chains(traces: Sequence[S], tune: int) -> Tuple[List[S], int]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This annotates it as returning a list of the same type of items as given in the input, but with the constraint that these items must be Sized.
| f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) " | ||
| f"took {t_sampling:.0f} seconds." | ||
| ) | ||
| mtrace.report._log_summary() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inbetween the line 574 mtrace = MultiTrace(traces)[:length] where the MultiTrace was created, no warnings were added to mtrace.
Therefore, there are no warnings to log and the _log_summary() call can safely be removed.
| warnings.warn( | ||
| "The number of samples is too small to check convergence reliably.", | ||
| stacklevel=2, | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is now checked by run_convergence_checks, just like it already checked for a minimum number of chains
| multitrace = MultiTrace(traces) | ||
| multitrace._report._log_summary() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here too: The multitrace can not have warnings that would be printed by _log_summary() because none were added here or in its __init__
| if idata is None: | ||
| idata = to_inference_data(trace, log_likelihood=False) | ||
| warns = run_convergence_checks(idata, model) | ||
| trace.report._add_warnings(warns) | ||
| log_warnings(warns) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This replaces the _compute_convergence_checks function and makes the trace.report be a dead end that can easily be removed in the future
Remember from other changes:
- "number of samples is too small" warning now done by
run_convergence_checks report._add_warningswas done insidereport._run_convergence_checkstrace.report._log_summary()internally calledlog_warnings()
| """Map between a `list` of stats to `dict` of stats.""" | ||
|
|
||
| def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None: | ||
| def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typing rule of thumb: Generic input types, exact output types.
I have only modified some docstrings 😅. @aloctavodia is the best choice I think |
aloctavodia
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
The goal was to uncouple sampling functions from
MultiTraceandSamplerReport.Some calls to
SamplerReport._log_summary()were unnecessary becauseMultiTrace._add_warnings()was never called inbetween instantiation and_log_summary(), therefore the traces never contained warnings.Running convergence checks and logging the warnings can also be done without needing
MultiTraceorSamplerReportinstances/methods.Checklist
Minor changes
"The number of samples is too small to check convergence reliably."warning is now anINFOlevel log message instead of aWarning.SamplerReport._log_summary()andSamplerReport._run_convergence_checksmethods were removed.Maintenance
MultiTraceorSamplerReportto compute/log warnings.